import os
import re
import json
import tqdm
import torch
import logging
import argparse
import numpy as np


import sys
sys.path.append(".")

import nlp

from overrides import overrides
from torch.nn import CrossEntropyLoss
from sklearn.metrics import accuracy_score

from cmlm_model import CMLModel

from transformers import RobertaTokenizer

# Script-wide logging: timestamped INFO-level records on the root logger,
# plus a module-level logger used throughout this file.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s'
_LOG_DATEFMT = '%m/%d/%Y %H:%M:%S'
logging.basicConfig(level=logging.INFO, datefmt=_LOG_DATEFMT, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)


class InstanceReader(object):
    """
    Base interface for dataset readers.

    Subclasses turn one raw JSON record (``fields``) into the uniform tuple
    ``(context, question, label, choices)`` and then into the model input
    ``(context, question, label, choices, context_with_choices)``.
    """

    def to_uniform_fields(self, fields):
        """Return ``(context, question, label, choices)`` for one raw record."""
        # Fail loudly instead of returning None, which would only surface later as a
        # confusing tuple-unpacking error in the caller.
        raise NotImplementedError(f"{type(self).__name__} must implement to_uniform_fields")

    def fields_to_instance(self, fields):
        """Return ``(context, question, label, choices, context_with_choices)``."""
        raise NotImplementedError(f"{type(self).__name__} must implement fields_to_instance")


class CopaInstanceReader(InstanceReader):
    """
    Reads the COPA dataset into a unified format with context, question, label, and choices.
    """
    @overrides
    def to_uniform_fields(self, fields):
        # The premise is the context; make sure it ends with a period.
        context = fields['premise']
        if not context.endswith("."):
            context += "."

        # COPA's question type ("cause"/"effect") becomes a natural-language connective.
        question = {"cause": "The cause for it was that", "effect": "As a result,"}[fields['question']]
        label = fields.get('label', None)
        choices = [fields['choice1'], fields['choice2']]
        return context, question, label, choices

    @overrides
    def fields_to_instance(self, fields):
        context, question, label, choices = self.to_uniform_fields(fields)
        # Lower-case the first character so the choice continues the sentence.
        # Slicing (rather than choice[0]) keeps an empty choice from raising IndexError.
        context_with_choices = [f"{context} {question} {choice[:1].lower() + choice[1:]}" for choice in choices]
        return context, question, label, choices, context_with_choices


class PiqaInstanceReader(InstanceReader):
    """
    Reads the PIQA dataset into a unified format with context, question, label, and choices.

    PIQA has no context: the goal is used as the question and each solution is appended to it.
    """
    @overrides
    def to_uniform_fields(self, fields):
        context = ""
        question = fields["goal"]
        label = fields.get('label', None)
        choices = [fields["sol1"], fields["sol2"]]
        return context, question, label, choices

    @overrides
    def fields_to_instance(self, fields):
        context, question, label, choices = self.to_uniform_fields(fields)
        # Lower-case the first character of the solution; slicing tolerates empty solutions.
        context_with_choices = [f"{question} {choice[:1].lower() + choice[1:]}" for choice in choices]
        return context, question, label, choices, context_with_choices


class SocialIQAInstanceReader(InstanceReader):
    """
    Reads the SocialIQa dataset into a unified format with context, question, label, and choices.

    Questions are rewritten into declarative answer prefixes (e.g. "What will X want to do
    next?" -> "As a result, X wanted to") that are prepended to each answer.
    """
    def __init__(self):
        # The original `super(SocialIQAInstanceReader).__init__()` created an *unbound*
        # super object and called __init__ on that object, not on the base class.
        super().__init__()
        # Question regex -> answer prefix; "[SUBJ]" is replaced by the regex's first group.
        # NOTE(review): templates are unescaped regexes tried in insertion order with
        # re.match (prefix match), so "?" makes the preceding character/group optional.
        # Some entries therefore shadow later, more specific ones (e.g.
        # "How would you describe (.*)?" also captures "X as a person?"). Kept unchanged
        # so the generated prompts stay identical.
        self.QUESTION_TO_ANSWER_PREFIX = {
              "What will (.*) want to do next?": r"As a result, [SUBJ] wanted to",
              "What will (.*) want to do after?": r"As a result, [SUBJ] wanted to",
              "How would (.*) feel afterwards?": r"As a result, [SUBJ] felt",
              "How would (.*) feel as a result?": r"As a result, [SUBJ] felt",
              "What will (.*) do next?": r"[SUBJ] then",
              "How would (.*) feel after?": r"[SUBJ] then",
              "How would you describe (.*)?": r"[SUBJ] is seen as",
              "What kind of person is (.*)?": r"[SUBJ] is seen as",
              "How would you describe (.*) as a person?": r"[SUBJ] is seen as",
              "Why did (.*) do that?": r"Before, [SUBJ] wanted",
              "Why did (.*) do this?": r"Before, [SUBJ] wanted",
              "Why did (.*) want to do this?": r"Before, [SUBJ] wanted",
              "What does (.*) need to do beforehand?": r"Before, [SUBJ] needed to",
              "What does (.*) need to do before?": r"Before, [SUBJ] needed to",
              "What does (.*) need to do before this?": r"Before, [SUBJ] needed to",
              "What did (.*) need to do before this?": r"Before, [SUBJ] needed to",
              "What will happen to (.*)?": r"[SUBJ] then",
              "What will happen to (.*) next?": r"[SUBJ] then"
        }

    @overrides
    def to_uniform_fields(self, fields):
        context = fields['context']
        if not context.endswith("."):
            context += "."

        question = fields['question']
        label = fields['correct']
        choices = [fields['answerA'], fields['answerB'], fields['answerC']]
        # Every answer ends with a period (this also guarantees answers are non-empty).
        choices = [c + "." if not c.endswith(".") else c for c in choices]
        # 'correct' is a letter "A"/"B"/"C"; convert it to a 0-based index.
        label = ord(label) - 65
        return context, question, label, choices

    @overrides
    def fields_to_instance(self, fields):
        context, question, label, choices = self.to_uniform_fields(fields)

        # The first template (in dict order) that matches the start of the question wins.
        answer_prefix = ""
        for template, ans_prefix in self.QUESTION_TO_ANSWER_PREFIX.items():
            m = re.match(template, question)
            if m is not None:
                answer_prefix = ans_prefix.replace("[SUBJ]", m.group(1))
                break

        # Fallback for questions that match no template.
        if answer_prefix == "":
            answer_prefix = question.replace("?", "is")

        # Prepend the prefix, drop any "?" and collapse verbs duplicated when the
        # answer repeats the prefix's last words ("wanted to wanted to" ...).
        choices = [
            " ".join((answer_prefix, choice[:1].lower() + choice[1:])).replace(
                "?", "").replace("wanted to wanted to", "wanted to").replace(
                "needed to needed to", "needed to").replace("to to", "to") for choice in choices]

        context_with_choices = [f"{context} {choice}" for choice in choices]
        return context, question, label, choices, context_with_choices


class WinograndeInstanceReader(InstanceReader):
    """
    Reads the WinoGrande dataset into a unified format with context, question, label, and choices.

    WinoGrande has no separate question: each option is substituted for the "_"
    blank in the sentence to form the scored text.
    """
    @overrides
    def to_uniform_fields(self, fields):
        sentence = fields['sentence']
        context = sentence if sentence.endswith(".") else sentence + "."

        raw_answer = fields['answer']
        choices = [fields['option1'], fields['option2']]
        # Answers are stored 1-based ("1"/"2"); convert to a 0-based index.
        label = int(raw_answer) - 1
        return context, '', label, choices

    @overrides
    def fields_to_instance(self, fields):
        context, question, label, choices = self.to_uniform_fields(fields)
        context_with_choices = list(map(lambda option: context.replace("_", option), choices))
        return context, question, label, choices, context_with_choices


class CommonsenseqaInstanceReader(InstanceReader):
    """
    Reads the CommonsenseQA dataset into a unified format with context, question, label, and choices.

    The context is taken as-is from the record's ``contexts`` field.
    """
    @overrides
    def to_uniform_fields(self, fields):
        context = fields["contexts"]

        question = fields['question']['stem']
        # answerKey is a letter A-E; unlabeled (test) records have no key.
        label = ['A','B','C','D','E'].index(fields['answerKey']) if "answerKey" in fields else None
        choices = [c['text'] for c in fields['question']['choices']]
        return context, question, label, choices

    @overrides
    def fields_to_instance(self, fields):
        context, question, label, choices = self.to_uniform_fields(fields)
        # Slicing keeps an empty choice text from raising IndexError.
        context_with_choices = [f"{question} {choice[:1].lower() + choice[1:]}" for choice in choices]
        return context, question, label, choices, context_with_choices


class MCTACOInstanceReader(InstanceReader):
    """
    Reads the MCTaco dataset into a unified format with context, question, label, and choices.
    """
    @overrides
    def to_uniform_fields(self, fields):
        context = fields['context']
        question = fields['question']
        choices = fields["choices"]
        label = fields.get("label", None)
        return context, question, label, choices

    @overrides
    def fields_to_instance(self, fields):
        context, question, label, choices = self.to_uniform_fields(fields)
        # Slicing keeps an empty choice from raising IndexError.
        context_with_choices = [f"{context} {question} {choice[:1].lower() + choice[1:]}" for choice in choices]
        return context, question, label, choices, context_with_choices
    
class SCIInstanceReader(InstanceReader):
    """
    Reads the SCI dataset (``context``, ``question.stem``, ``question.choices`` and an
    optional ``answerKey``) and pads every instance to eight choices.
    """
    @overrides
    def to_uniform_fields(self, fields):
        context = fields["context"]
        question = fields['question']["stem"]
        choices = fields['question']["choices"]

        if "answerKey" not in fields:
            label = None
        else:
            answer_key = fields['answerKey']
            # Keys may be letters ("A".."H") or digits ("1".."4").
            # Only a failed lookup falls through; the old bare `except:` hid any error.
            try:
                label = ['A','B','C','D','E','F','G','H'].index(answer_key)
            except ValueError:
                label = ['1','2','3','4'].index(answer_key)
        return context, question, label, choices

    @overrides
    def fields_to_instance(self, fields):
        context, question, label, choices = self.to_uniform_fields(fields)
        choices = [choice['text'] for choice in choices]
        # Pad to exactly eight choices with a filler answer; longer lists are left as-is.
        choices = choices + ["dont know"] * max(0, 8 - len(choices))
        # Slicing keeps an empty choice text from raising IndexError.
        context_with_choices = [f"{context} {question} {choice[:1].lower() + choice[1:]}" for choice in choices]
        return context, question, label, choices, context_with_choices
    
class AnliInstanceReader(InstanceReader):
    """
    Reads the ANLI dataset: the two observations in ``premises[0]`` become context and
    question, and each candidate hypothesis is placed between them.
    """
    @overrides
    def to_uniform_fields(self, fields):
        context = fields['premises'][0][0]
        question = fields['premises'][0][1]
        choices = fields["choices"]
        label = fields["gold_label"]
        return context, question, label, choices

    @overrides
    def fields_to_instance(self, fields):
        context, question, label, choices = self.to_uniform_fields(fields)
        # Order is context, choice, question; slicing tolerates an empty choice.
        context_with_choices = [f"{context} {choice[:1].lower() + choice[1:]} {question}" for choice in choices]
        return context, question, label, choices, context_with_choices

def pad_tokens(input_ids, block_size, tokenizer):
    """
    Fit ``input_ids`` to exactly ``block_size - 2`` ids.

    Longer sequences are truncated; shorter ones are right-padded with
    ``tokenizer.pad_token_id``. Always returns a new list.
    """
    target_len = block_size - 2
    if len(input_ids) < target_len:
        missing = target_len - len(input_ids)
        return input_ids + [tokenizer.pad_token_id] * missing
    return input_ids[:target_len]

def generate_tokens(context, question, choices, tokenizer, is_anli=False, block_size=80):
    """
    Build the four token-id tensors the model scores for every answer choice.

    For each choice the token sequence ``context . question choice`` is produced in four
    variants: context masked, question masked, choice masked, and unmasked (the target).
    Masked spans are replaced token-for-token by ``tokenizer.mask_token``, so all four
    variants have the same length before padding. Each sequence is then fitted by
    ``pad_tokens`` to ``block_size - 2`` ids.

    :param context: context string (may be empty)
    :param question: question string
    :param choices: list of answer strings
    :param tokenizer: tokenizer providing ``tokenize``, ``convert_tokens_to_ids``,
        ``mask_token`` and ``pad_token_id``
    :param is_anli: accepted for call-site compatibility; not used here
    :param block_size: size handed to ``pad_tokens`` (default 80, as before)
    :return: tuple ``(cmasked, qmasked, chmasked, target)`` of tensors with one row per choice
    """
    mask = tokenizer.mask_token
    ctokens = tokenizer.tokenize(context)
    qtokens = tokenizer.tokenize(question)
    cmask = [mask] * len(ctokens)
    qmask = [mask] * len(qtokens)

    cmasked_list, qmasked_list, chmasked_list, target_list = [], [], [], []
    for choice in choices:
        ptokens = tokenizer.tokenize(choice)
        pmask = [mask] * len(ptokens)
        # A literal "." token separates the context from the question in every variant.
        variants = (
            (cmasked_list, cmask + ["."] + qtokens + ptokens),
            (qmasked_list, ctokens + ["."] + qmask + ptokens),
            (chmasked_list, ctokens + ["."] + qtokens + pmask),
            (target_list, ctokens + ["."] + qtokens + ptokens),
        )
        for bucket, tokens in variants:
            bucket.append(pad_tokens(tokenizer.convert_tokens_to_ids(tokens), block_size, tokenizer))

    return (torch.tensor(cmasked_list), torch.tensor(qmasked_list),
            torch.tensor(chmasked_list), torch.tensor(target_list))
                                
def create_tensors(context, question, choices, tokenizer, is_anli=False):
    """
    Build label / context-masked / question-masked / answer-masked id tensors.

    Context, question and each choice are encoded separately with ``tokenizer.encode``.
    The pieces are concatenated as context + question + choice, or as
    context + choice + question when ``is_anli``. Choices (and their masks) are
    right-padded with ``tokenizer.pad_token_id`` to the length of the longest
    encoded choice, so all rows share one length.

    :param context: context string; an empty/falsy context contributes no tokens
    :param question: question string
    :param choices: non-empty list of answer strings
    :param tokenizer: tokenizer providing ``encode``, ``convert_tokens_to_ids``,
        ``mask_token`` and ``pad_token_id``
    :param is_anli: place the choice between context and question
    :return: ``(label, context_masked, question_masked, answer_masked)`` tensors, one row per choice
    :raises ValueError: if ``choices`` is empty
    """
    context_tokens = tokenizer.encode(context) if context else []
    question_tokens = tokenizer.encode(question)

    def masked(token_ids):
        # One mask id per id being hidden, so masked spans keep their length.
        return tokenizer.convert_tokens_to_ids([tokenizer.mask_token] * len(token_ids))

    context_mask = masked(context_tokens)
    question_mask = masked(question_tokens)

    choice_tokens = [tokenizer.encode(choice) for choice in choices]
    if not choice_tokens:
        raise ValueError("create_tensors needs at least one choice")
    max_length = max(len(ids) for ids in choice_tokens)

    def pad(ids):
        # Pad to exactly max_length. pad_tokens() was used here before, but it targets
        # block_size - 2 and therefore cut the last two ids off the longest choice.
        return ids + [tokenizer.pad_token_id] * (max_length - len(ids))

    def join(ctx, q, ans):
        return ctx + ans + q if is_anli else ctx + q + ans

    label_tokens, context_masked, question_masked, answer_masked = [], [], [], []
    for choice_token in choice_tokens:
        choice_ids = pad(choice_token)
        choice_mask = pad(masked(choice_token))
        label_tokens.append(join(context_tokens, question_tokens, choice_ids))
        context_masked.append(join(context_mask, question_tokens, choice_ids))
        question_masked.append(join(context_tokens, question_mask, choice_ids))
        answer_masked.append(join(context_tokens, question_tokens, choice_mask))
    return (torch.tensor(label_tokens), torch.tensor(context_masked),
            torch.tensor(question_masked), torch.tensor(answer_masked))


# Maps the --type command-line value to the reader class that parses that dataset.
INSTANCE_READERS = dict(
    copa=CopaInstanceReader,
    socialiqa=SocialIQAInstanceReader,
    winogrande=WinograndeInstanceReader,
    piqa=PiqaInstanceReader,
    commonsenseqa=CommonsenseqaInstanceReader,
    mctaco=MCTACOInstanceReader,
    sci=SCIInstanceReader,
    anli=AnliInstanceReader,
)

                            

def main():
    """
    Command-line entry point.

    Loads the model and tokenizer, evaluates ``--dataset_val_file`` with the reader
    selected by ``--type`` and writes one accuracy line per model score to
    ``<out_dir>/<type>_acc.jsonl``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--lm", default="roberta-base", type=str, required=False, help="language model to use")
    parser.add_argument("--dataset_train_file", default=None, type=str, required=False, help="Jsonl file")
    parser.add_argument("--dataset_val_file", default=None, type=str, required=False, help="Jsonl file")
    parser.add_argument("--out_dir", default=None, type=str, required=False, help="Out directory for the predictions")
    parser.add_argument("--device", default=1, type=int, required=False, help="GPU device")
    parser.add_argument("--type", default=None, type=str, required=False, help="Dataset type")

    args = parser.parse_args()
    logger.info(args)

    # Fail early with a usage message instead of a TypeError/KeyError deep inside main_ex.
    if args.dataset_val_file is None:
        parser.error("--dataset_val_file is required")
    if args.type not in INSTANCE_READERS:
        parser.error(f"--type must be one of: {', '.join(sorted(INSTANCE_READERS))}")

    # --out_dir used to be overwritten unconditionally; the old location is now only
    # the fallback when the flag is not given.
    if args.out_dir is None:
        args.out_dir = "/data/datasets/ktl"
    os.makedirs(args.out_dir, exist_ok=True)
    out_file = os.path.join(args.out_dir, f"{args.type}_acc.jsonl")

    # Load the model before opening (and truncating) the output file, so a failed
    # load does not wipe earlier results.
    # NOTE(review): --device is parsed but not used; the model and inputs are moved
    # with .cuda() elsewhere in this file.
    model, tokenizer = init_model(args.lm, 0)
    with open(out_file, "w") as f_out:
        try:
            main_ex(args, model, tokenizer, dataset_file=args.dataset_val_file,
                    stype="val", device=0, f_out=f_out)
        except Exception:
            logger.exception("Unexpected error while evaluating %s", args.dataset_val_file)
            raise

def score_min(nscore, old_score):
    """
    Element-wise minimum of running scores and a new set of scores.

    Despite the parameter names, ``nscore`` holds the running per-choice scores and
    ``old_score`` the newly computed ones; only its first row ``old_score[0]`` is used.

    :return: list of ``min(nscore[i], old_score[0][i])``, truncated to the shorter input
    """
    # Loop names changed: the old `os` shadowed the os module inside this function.
    return [min(running, new) for running, new in zip(nscore, old_score[0])]

def score_add(nscore, old_score):
    """
    Element-wise sum of running scores and a new set of scores.

    ``nscore`` holds the running per-choice scores; only the first row
    ``old_score[0]`` of the new scores is used.

    :return: list of ``nscore[i] + old_score[0][i]``, truncated to the shorter input
    """
    # Loop names changed: the old `os` shadowed the os module inside this function.
    return [running + new for running, new in zip(nscore, old_score[0])]

def score_mul(nscore, old_score):
    """
    Element-wise product of running scores and a new set of scores.

    ``nscore`` holds the running per-choice scores; only the first row
    ``old_score[0]`` of the new scores is used.

    :return: list of ``nscore[i] * old_score[0][i]``, truncated to the shorter input
    """
    # Loop names changed: the old `os` shadowed the os module inside this function.
    return [running * new for running, new in zip(nscore, old_score[0])]

def main_ex(args, model, tokenizer, dataset_file, stype, device, f_out):
    """
    Predict every instance in ``dataset_file`` and report one accuracy per model score.

    For each instance, up to the first three contexts are scored with ``model.score``.
    Each of the seven returned per-choice scores is aggregated across contexts with
    ``score_min``, and the choice with the lowest aggregated score is predicted.
    Accuracies are written only when every instance carries a label.

    :param args: parsed CLI args; ``args.type`` selects the instance reader
    :param model: object with ``score(cmasked, qmasked, amasked, target, None)`` that is
        expected to return seven per-choice score tensors -- TODO confirm against CMLModel
    :param tokenizer: tokenizer forwarded to ``generate_tokens``
    :param dataset_file: JSONL file, one instance per line
    :param stype: split name written next to each accuracy (e.g. "val")
    :param device: unused; inputs are moved with ``.cuda()``
    :param f_out: writable text file receiving the accuracy lines
    """
    num_scores = 7
    instance_reader = INSTANCE_READERS[args.type]()

    # Lower aggregated score = better choice (argmin below).
    aggregate = score_min
    # Seed with the aggregation's identity element. The previous seed of 1 is only the
    # identity of score_mul: under score_min any score above 1 was clamped to 1, which
    # tied all choices and made argmin always pick choice 0.
    seed = {score_min: float("inf"), score_add: 0.0, score_mul: 1.0}[aggregate]

    gold = []
    predictions = [[] for _ in range(num_scores)]

    with open(dataset_file) as f_in:
        for line in tqdm.tqdm(f_in):
            line = line.strip()
            if not line:  # tolerate blank lines, e.g. a trailing newline
                continue
            fields = json.loads(line)
            context, question, label, choices, _ = instance_reader.fields_to_instance(fields)
            gold.append(label)

            # Some readers (e.g. COPA, PIQA) return one context string, others a list of
            # contexts. Slicing a string would iterate over its characters.
            contexts = [context] if isinstance(context, str) else context

            scores = [[seed] * len(choices) for _ in range(num_scores)]
            for ctxt in contexts[0:3]:
                cmasked, qmasked, amasked, target = generate_tokens(ctxt, question, choices, tokenizer, is_anli=False)
                with torch.no_grad():
                    preds = model.score(cmasked.cuda(), qmasked.cuda(), amasked.cuda(), target.cuda(), None)
                if len(preds) != num_scores:
                    raise ValueError(f"model.score returned {len(preds)} scores, expected {num_scores}")
                scores = [aggregate(running, p.cpu().numpy()) for running, p in zip(scores, preds)]

            for bucket, per_choice in zip(predictions, scores):
                bucket.append(int(np.argmin(per_choice)))

    # Don't report accuracy if we don't have the labels (or there were no instances).
    if gold and None not in gold:
        print("\nAccuracies:")
        for idx, bucket in enumerate(predictions, start=1):
            accuracy_f1(gold, bucket, f"f{idx}", stype, f_out)

def accuracy_f1(gold, predictions, ftype, stype, f_out):
    """
    Compute, print and record the accuracy of ``predictions`` against ``gold``.

    Despite the name this reports plain accuracy (``sklearn.metrics.accuracy_score``),
    not F1.

    :param gold: gold label indices
    :param predictions: predicted label indices, same length as ``gold``
    :param ftype: score name, e.g. "f1" .. "f7"
    :param stype: split name, e.g. "val"
    :param f_out: writable text file; one line is appended
    :return: the accuracy (previously None; returning it is backward compatible)
    """
    accuracy = accuracy_score(gold, predictions)
    message = f"Accuracy: {stype} {ftype} {accuracy:.3f}"
    # Label the console line too; previously seven indistinguishable
    # "Accuracy: x" lines were printed.
    print(message, flush=True)
    f_out.write(message + "\n")
    return accuracy

def get_lm_score(model, batch, pad_token_id=None):
    """
    Get the per-row cross entropy loss of the texts in ``batch`` using a causal language model.

    Each position predicts the next token (logits shifted left by one, labels by one).

    :param model: callable whose first output is logits presumably shaped
        (num_choices, max_length, vocab_size) -- confirm for the model in use
    :param batch: 2-D LongTensor of token ids, shape (num_choices, max_length)
    :param pad_token_id: if given, positions whose *label* equals it are excluded from the
        mean; by default every position is averaged (original behavior)
    :return: numpy array with one mean loss per row
    """
    with torch.no_grad():
        num_choices, _ = batch.shape
        shift_labels = batch[..., 1:].contiguous().view(-1)
        lm_logits = model(batch)[0]
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_logits = shift_logits.view(-1, shift_logits.size(-1))
        loss_fct = CrossEntropyLoss(reduction="none")
        loss = loss_fct(shift_logits, shift_labels).view(num_choices, -1)
        if pad_token_id is None:
            loss = loss.mean(1)
        else:
            # Average only over real tokens so padding does not bias shorter rows;
            # clamp avoids division by zero for an all-padding row.
            keep = (shift_labels != pad_token_id).view(num_choices, -1).to(loss.dtype)
            loss = (loss * keep).sum(1) / keep.sum(1).clamp(min=1.0)

    return loss.cpu().numpy()


def init_model(model_name: str,
               device: torch.device):
    """
    Load the CMLM model and a RoBERTa tokenizer, then prepare the model for inference.

    :param model_name: name or path handed to ``CMLModel.from_pretrained``
    :param device: currently unused; the model is moved with ``.cuda()``
    :return: ``(model, tokenizer)``; the tokenizer is always ``roberta-large``,
        independent of ``model_name``
    """
    logger.info(f'Initializing {model_name}')
    roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
    cmlm = CMLModel.from_pretrained(model_name)
    # Move to the default CUDA device and disable dropout for scoring.
    cmlm.cuda()
    cmlm.eval()
    return cmlm, roberta_tokenizer


if __name__ == "__main__":  # run only when executed as a script, not on import
    main()
